(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載)
之前曾經提到,DeviceArray 是 JAX 自行定義的陣列類別,定義在 jax.numpy.DeviceArray,它的角色等同於 Numpy 中的 ndarray,現在就讓我們來更進一步的認識這個類別 [9.1]。
我們通常不需要直接宣告一個 DeviceArray 物件 (object, 或者也可以稱之為案例 instance),許多 jax.numpy 的 API 可以協助我們產生所需的 DeviceArray。例如:
jax.numpy.append(arr, values, axis=None)
jax.numpy.arange(start, stop=None, step=None, dtype=None)
jax.numpy.array(object, dtype=None, copy=True, order='K', ndmin=0)
jax.numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0)
jax.numpy.ones(shape, dtype=None)
jax.numpy.zeros(shape, dtype=None)
這些 API 在 Numpy 皆有對應的函式,語法和用法幾乎完全一樣,差別在於 Numpy 回傳的是 ndarray ,而 JAX 回傳 DeviceArray 。值得一提的是 jax.numpy.array() ,我們可以用它來直接將 ndarray 轉換為 DeviceArray。
# create numpy ndarray
x_np = np.arange(10)
print(f'type of x_np: {type(x_np)}')
print(f'value of x_np:{x_np}')
# convert to DeviceArray
x_jnp = jnp.array(x_np)
print(f'type of x_jnp: {type(x_jnp)}')
print(f'value of x_jnp: {x_jnp}')
output:
type of x_np: <class 'numpy.ndarray'>
value of x_np:[0 1 2 3 4 5 6 7 8 9]
type of x_jnp: <class 'jaxlib.xla_extension.DeviceArray'>
value of x_jnp: [0 1 2 3 4 5 6 7 8 9]
當我們使用 python 的 type() 來檢查 DeviceArray 型別的變數時,回傳的是:
<class 'jaxlib.xla_extension.DeviceArray'>
這是 DeviceArray 在 JAX Python 庫 “jaxlib” 實際的位置,不過為了方便,JAX 提供了別名 (alias) jax.numpy.DeviceArray ,讓大家使用。
# jax.numpy.DeviceArray is an alias of jaxlib.xla_extension.DeviceArray
isinstance(x_jnp, jnp.DeviceArray)
output:
True
不可變 (immutability) 是 jax.numpy 和 Numpy 最大的差異之一!初學 JAX 的讀者,要非常注意這個部份。
在 Numpy 中我們可習慣使用以下的程式段來更改陣列元素的值:
# numpy is mutible
x = np.arange(10)
print(f'Before assignment : {x}')
x[0] = 10
print(f'After assignment : {x}')
output:
Before assignment : [0 1 2 3 4 5 6 7 8 9]
After assignment : [10 1 2 3 4 5 6 7 8 9]
然而同樣的方式,在 JAX 的 DeviceArray 上則會造成執行時錯誤。
# JAX/DeviceArray is immutible
x_jnp = jnp.arange(10)
print(f'Before assignment : {x_jnp}')
x_jnp[0] = 10
output:
Before assignment : [0 1 2 3 4 5 6 7 8 9]…
*TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead ofx[idx] = y
, usex = x.at[idx].set(y)
or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html *
因為 DeviceArray 是不可變的,所以上面程式的變數 x_jnp 在被創造出來之後,其值就不可以改變。 JAX 提供的折衷方法,是利用 x.at[idx].set(y) 這種方式:
# For updating individual elements, JAX provides an indexed update syntax that returns
# an updated copy:
x = jnp.arange(10)
y = x.at[0].set(10)
print(f'x: {x}')
print(f'y: {y}')
output:
x: [0 1 2 3 4 5 6 7 8 9]
y: [10 1 2 3 4 5 6 7 8 9]
要注意 x.at[0].set(10) 這個敍述式 (expression) 並不會更改 x 的值,它是將 x 複製一份,在副本上修改索引 0 的值為 10。所以下面這個常用的敍述:
x = x.at[0].set(10)
在執行之後,變數 x 其實已經參考到不同的記憶體位址了。我們可以檢驗看看:
# .at[].set() copy the DeviceArray.
x2 = jnp.arange(5.0)
print(f'Before at.set: {id(x2)}')
x2 = x2.at[0].set(9.9)
print(f'After at.set : {id(x2)}')
print(x2)
output:
Before at.set: 139712279962800
After at.set : 139712279963568
[9.9 1. 2. 3. 4. ]
DeviceArray.at[ ] 也可以指定索引範圍:
x2.at[2:4].set(88.88)
output:
DeviceArray([ 9.9 , 1. , 88.88, 88.88, 4. ], dtype=float32)
除了 set () 之外,DeviceArray.at[ ] 還提供了其他的運算,下表是 JAX 官方文件 [9.1] 所列出的操作,以及其對應的 Numpy (In-place) 語法,供大家參考:
用 at[].add() 舉個例子:
jax_array = jnp.ones((5, 6))
print("original array:")
print(jax_array)
new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)
output:
original array:
[[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 8. 8. 8.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 8. 8. 8.]]
註:
[9.1] 可以參考 JAX 官方文件 jax.numpy package。